{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "170ea770-eb8f-4589-89c5-ee5c5a212b07",
   "metadata": {},
   "source": [
    "# Understand Vector Sampler\n",
    "\n",
    "`VectorSampler` is a utility for directly accessing and exporting the vectors stored in an `Index`. It returns them as a `VectorCollection`, which is essentially a numpy array plus a list of ids, one per row."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c13eb567-0431-4320-9e81-742fc17b59b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install superlinked==3.38.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a0f803c0-3900-4cc0-9ed5-75016d476aff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from superlinked.evaluation.vector_sampler import VectorSampler\n",
    "from superlinked.framework.common.schema.id_schema_object import IdField\n",
    "from superlinked.framework.common.schema.schema import schema\n",
    "from superlinked.framework.common.schema.schema_object import String\n",
    "from superlinked.framework.dsl.index.index import Index\n",
    "from superlinked.framework.dsl.space.categorical_similarity_space import (\n",
    "    CategoricalSimilaritySpace,\n",
    ")\n",
    "from superlinked.framework.dsl.space.text_similarity_space import TextSimilaritySpace\n",
    "\n",
    "from superlinked.framework.dsl.executor.in_memory.in_memory_executor import (\n",
    "    InMemoryExecutor,\n",
    ")\n",
    "from superlinked.framework.dsl.source.in_memory_source import InMemorySource\n",
    "\n",
    "pd.set_option(\"display.max_colwidth\", 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90b37173-87a5-411b-b0f5-536e6e6f87de",
   "metadata": {},
   "source": [
    "## Load data into Superlinked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0454ae3a-9e33-47a5-8483-9edc474d78a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "@schema\n",
    "class Paragraph:\n",
    "    id: IdField\n",
    "    body: String\n",
    "    category: String\n",
    "\n",
    "\n",
    "paragraph = Paragraph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e6fc5b3d-41d4-4faa-a7af-e602d5e65f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "body_space = TextSimilaritySpace(\n",
    "    text=paragraph.body, model=\"sentence-transformers/all-mpnet-base-v2\"\n",
    ")\n",
    "category_space = CategoricalSimilaritySpace(\n",
    "    category_input=paragraph.category,\n",
    "    categories=[\"category-1\", \"category-2\", \"category-3\"],\n",
    ")\n",
    "\n",
    "paragraph_index = Index([body_space, category_space])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "92782d34-18a3-4280-8c91-fb5356ec90e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "source: InMemorySource = InMemorySource(paragraph)\n",
    "executor = InMemoryExecutor(sources=[source], indices=[paragraph_index])\n",
    "app = executor.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b5f6d018-0fac-43f9-9166-b31ae6fc7958",
   "metadata": {},
   "outputs": [],
   "source": [
    "source.put(\n",
    "    [\n",
    "        {\n",
    "            \"id\": \"paragraph-1\",\n",
    "            \"body\": \"Glorious animals live in the wilderness.\",\n",
    "            \"category\": \"category-2\",\n",
    "        },\n",
    "        {\n",
    "            \"id\": \"paragraph-2\",\n",
    "            \"body\": \"Growing computation power enables advancements in AI.\",\n",
    "            \"category\": \"category-3\",\n",
    "        },\n",
    "        {\n",
    "            \"id\": \"paragraph-3\",\n",
    "            \"body\": \"Processed foods are generally worse for your health than raw vegetables.\",\n",
    "            \"category\": \"category-1\",\n",
    "        },\n",
    "        {\n",
    "            \"id\": \"paragraph-4\",\n",
    "            \"body\": \"The fauna of distant places can surprise travelers.\",\n",
    "            \"category\": \"category-2\",\n",
    "        },\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46507517-9ab7-4a06-87ad-0847cea51f7a",
   "metadata": {},
   "source": [
    "## Using a Vector Sampler\n",
    "\n",
    "Create a `VectorSampler` by passing it the running `app` returned by `executor.run()`. It can then export the vectors of an index, one schema at a time, into a `VectorCollection`. A collection can hold every vector of the schema, or only the vectors matching a single id or a list of ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "819b3243-c5b5-4e51-b09f-676b87401770",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_sampler = VectorSampler(app=app)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4224f061-2c33-422b-9b15-11d2436e2963",
   "metadata": {},
   "source": [
    "### Get a subset of vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a55750c9-2bf8-45ad-b7b1-c72972f67251",
   "metadata": {},
   "source": [
    "A `VectorCollection` object is essentially a numpy array (vectors) with shape `(num_entities, vector_dims)` and a corresponding `id_list` where `id_list[i]` is the id of `vectors[i, :]`."
   ]
  },
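  {
   "cell_type": "markdown",
   "id": "3f9a1c52-7b4e-4d21-9a6c-2e8d5b7f0a13",
   "metadata": {},
   "source": [
    "As a minimal sketch of that row-to-id correspondence, the snippet below builds an id-to-row lookup with plain NumPy. The `id_list` and `vectors` values are made-up stand-ins for the attributes of a real `VectorCollection`, and `row_of` and `vector_for` are hypothetical helper names, not part of Superlinked.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in data (hypothetical values): 3 entities with 4-dimensional vectors.\n",
    "id_list = ['paragraph-1', 'paragraph-2', 'paragraph-3']\n",
    "vectors = np.arange(12, dtype=float).reshape(3, 4)\n",
    "\n",
    "# id_list[i] labels vectors[i, :], so an id-to-row lookup is a dict of positions.\n",
    "row_of = {id_: row for row, id_ in enumerate(id_list)}\n",
    "\n",
    "\n",
    "def vector_for(id_):\n",
    "    # Return the row of `vectors` that belongs to `id_` (hypothetical helper).\n",
    "    return vectors[row_of[id_]]\n",
    "\n",
    "\n",
    "second = vector_for('paragraph-2')\n",
    "```\n",
    "\n",
    "With a real collection, the same lookup should work on `collection.id_list` and `collection.vectors`."
   ]
  },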
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "521d50c6-ef0c-42e9-b986-6d7431431189",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VectorCollection of 1 vector."
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "singular_vector_collection = vector_sampler.get_vectors_by_ids(\n",
    "    id_=\"paragraph-1\", index=paragraph_index, schema=paragraph\n",
    ")\n",
    "singular_vector_collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "381455a3-e481-4cb4-b1ac-704fe9fdcf2a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['paragraph-1']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "singular_vector_collection.id_list  # the id we requested"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bcb7b378-8aac-4277-aa17-1b75e4801452",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 772)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "singular_vector_collection.vectors.shape  # 1 vector, 768 dimensions for text embedding, 4 for categorical embedding (3 categories and other)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "771ffc55-e271-4e3a-85de-aa43d13aabde",
   "metadata": {},
   "source": [
    "### Get all vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c94e6e2c-5434-47f3-bd41-64e665ccb89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_collection = vector_sampler.get_all_vectors(\n",
    "    index=paragraph_index, schema=paragraph\n",
    ")  # return all vectors of a schema in an index\n",
    "id_list, vector_array = vector_collection.id_list, vector_collection.vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6024c377-4cea-4a03-961a-3b195994f9f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VectorCollection of 4 vectors."
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vector_collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8e640c87-85e5-4961-8901-cfab97c442e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['paragraph-1', 'paragraph-2', 'paragraph-3', 'paragraph-4']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id_list  # all 4 vector ids"
   ]
  },
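  {
   "cell_type": "markdown",
   "id": "8c2e4b17-0d5a-4f36-b1e9-6a7d3c9f2e40",
   "metadata": {},
   "source": [
    "Once exported, the array can be analysed with ordinary NumPy operations. The sketch below computes pairwise cosine similarities and finds the closest other row for each id. It uses small made-up values in place of the real `id_list` and `vector_array`, and the names `unit_rows`, `similarity`, `masked` and `nearest` are illustrative only.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Stand-in data (hypothetical values) shaped like the export: one row per id.\n",
    "id_list = ['paragraph-1', 'paragraph-2', 'paragraph-3']\n",
    "vector_array = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])\n",
    "\n",
    "# Scale every row to unit length; the matrix product then holds all pairwise cosines.\n",
    "unit_rows = vector_array / np.linalg.norm(vector_array, axis=1, keepdims=True)\n",
    "similarity = unit_rows @ unit_rows.T\n",
    "\n",
    "# Mask the diagonal so that a row is never reported as its own nearest neighbour.\n",
    "masked = similarity.copy()\n",
    "np.fill_diagonal(masked, -np.inf)\n",
    "nearest = {id_list[i]: id_list[j] for i, j in enumerate(masked.argmax(axis=1))}\n",
    "```\n",
    "\n",
    "This assumes no row is an all-zero vector, since dividing by a zero norm would produce `nan` values."
   ]
  },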
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "57ebcdcf-9d8f-4eff-8332-349bf48ef889",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 772)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vector_array.shape  # 4 vectors, 768 dimensions for text embedding, 4 for categorical embedding (3 categories and other)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "superlinked-py3.10",
   "language": "python",
   "name": "superlinked-py3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
